library(ggplot2)
library(tidyr)
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(rlang)
library(reshape2)
##
## Attaching package: 'reshape2'
## The following object is masked from 'package:tidyr':
##
## smiths
library(RColorBrewer)
library(mgcv)
## Loading required package: nlme
##
## Attaching package: 'nlme'
## The following object is masked from 'package:dplyr':
##
## collapse
## This is mgcv 1.9-1. For overview type 'help("mgcv-package")'.
library(caret)
## Loading required package: lattice
# Load data (expected to provide the data frame `dat1`)
load("dat1.RData")
# Define response variable
response_var <- "log_antibody"
# Define continuous variables
continuous_vars <- c(
  "age",    # Age of the participant
  "height", # Height in cm
  "weight", # Weight in kg
  "bmi",    # Body Mass Index
  "SBP",    # Systolic Blood Pressure
  "LDL",    # Low-Density Lipoprotein
  "time"    # Time measurement
)
# Define categorical variables: every remaining column that is neither the
# id, the response, nor one of the continuous predictors
categorical_vars <- setdiff(names(dat1), c("id", response_var, continuous_vars))
# Keep only numeric or factor columns. vapply() guarantees a logical vector
# (sapply() would return a list for empty input, which cannot be used as a
# subscript).
is_categorical <- vapply(
  dat1[categorical_vars],
  function(x) is.numeric(x) || is.factor(x),
  logical(1)
)
categorical_vars <- categorical_vars[is_categorical]
# Convert categorical variables to factors
dat1[categorical_vars] <- lapply(dat1[categorical_vars], factor)
# Print variable types for verification
cat("Response variable:", response_var, "\n")
## Response variable: log_antibody
cat("Continuous variables:", paste(continuous_vars, collapse = ", "), "\n")
## Continuous variables: age, height, weight, bmi, SBP, LDL, time
cat("Categorical variables:", paste(categorical_vars, collapse = ", "), "\n")
## Categorical variables: gender, race, smoking, diabetes, hypertension
# Prepare data for plotting: reshape to long format so that each row holds the
# response plus one (Variable, Value) pair for a continuous predictor
long_df <- dat1 %>%
  select(all_of(c(response_var, continuous_vars))) %>%
  pivot_longer(
    cols = all_of(continuous_vars),
    names_to = "Variable",
    values_to = "Value"
  )
# Plot relationships between all continuous variables and response variable:
# dashed red = linear fit, solid blue = loess fit, one facet per predictor
ggplot(long_df, aes(x = Value, y = .data[[response_var]])) +
  geom_point(alpha = 0.3, color = "grey30") +
  geom_smooth(method = "lm", se = FALSE, color = "red", linetype = "dashed") +
  geom_smooth(method = "loess", se = FALSE, color = "blue") +
  facet_wrap(~ Variable, scales = "free_x", ncol = 3) +
  theme_minimal(base_size = 14) +
  labs(
    title = paste("Linear vs Nonlinear Relationships with", response_var),
    x = "Predictor",
    # The response name already carries the "log_" prefix; prefixing it with
    # "Log of" would wrongly suggest a double log transform.
    y = response_var
  )
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
# Create individual relationship plots for each continuous variable:
# solid red = linear fit, dashed blue = loess fit
for (var in continuous_vars) {
  # `.data[[name]]` replaces the deprecated aes_string() for string column names
  p <- ggplot(dat1, aes(x = .data[[var]], y = .data[[response_var]])) +
    geom_point(alpha = 0.2, color = "grey30") +
    # `linewidth` replaces `size` for line geoms (ggplot2 >= 3.4.0)
    geom_smooth(method = "lm", se = FALSE, color = "red", linewidth = 1) +
    geom_smooth(method = "loess", se = FALSE, color = "blue", linetype = "dashed") +
    theme_bw(base_size = 16) +
    labs(
      title = paste(response_var, "vs", var),
      x = var,
      # The response name already carries the "log_" prefix, so it is used as is
      y = response_var
    )
  # Plots built inside a loop are not auto-printed, so print explicitly
  print(p)
}
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
# Prepare numeric variables data: the response plus all continuous predictors
num_vars <- c(response_var, continuous_vars)
df_num <- dat1[, num_vars]
# Calculate correlation matrix on complete rows only, rounded to 2 decimals
# so the values can be printed inside the tiles
cor_matrix <- round(cor(df_num, use = "complete.obs"), 2)
# Convert to long format (Var1, Var2, value) and create heatmap
cor_melted <- melt(cor_matrix)
ggplot(cor_melted, aes(x = Var1, y = Var2, fill = value)) +
  geom_tile(color = "white") +
  scale_fill_gradient2(
    low = "blue", high = "red", mid = "white",
    midpoint = 0,
    # Spelled out in full; `limit` only worked through partial argument matching
    limits = c(-1, 1),
    space = "Lab",
    name = "Correlation"
  ) +
  geom_text(aes(label = value), color = "black", size = 4) +
  theme_minimal(base_size = 14) +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
  labs(
    title = paste("Correlation Matrix of", response_var, "and Continuous Variables"),
    x = "", y = ""
  )
# Standalone mgcv GAM: bmi and time enter as smooth terms via s(); all other
# predictors enter linearly. weight is not part of this formula.
gam_model <- gam(
  formula = log_antibody ~ age + height + s(bmi) + SBP + LDL + s(time) +
    gender + race + smoking + diabetes + hypertension,
  data = dat1
)
# Resampling scheme for the caret models: 10-fold cross-validation
ctrl <- trainControl(
  method = "cv",
  number = 10
)
Avoid multicollinearity: because BMI is a function of weight and height, we keep height and BMI in the models and drop weight.
# One formula shared by all four caret models so they cannot drift apart.
# weight is left out; height and bmi are kept.
model_formula <- log_antibody ~ age + height + bmi + SBP + LDL + time +
  gender + smoking + race + diabetes + hypertension

# The seed is reset before EVERY train() call so that each model is meant to
# be resampled on the same cross-validation folds. Without this the
# resamples() comparison further down is not a like-for-like comparison.
# (Previously the glmnet fit was the only one without a preceding set.seed().)

# Linear regression
set.seed(1)
model.glm <- train(
  model_formula,
  data = dat1,
  method = "glm",
  trControl = ctrl
)

# Penalized linear regression; caret picks 10 candidate values per tuning
# parameter
set.seed(1)
model.glmnet <- train(
  model_formula,
  data = dat1,
  method = "glmnet",
  trControl = ctrl,
  tuneLength = 10
)

# Generalized additive model via caret
set.seed(1)
model.gam <- train(
  model_formula,
  data = dat1,
  method = "gam",
  trControl = ctrl
)

# MARS (earth), tuned over interaction degree 1-4 and 2-20 retained terms
set.seed(1)
model.mars <- train(
  model_formula,
  data = dat1,
  method = "earth",
  tuneGrid = expand.grid(degree = 1:4, nprune = 2:20),
  trControl = ctrl
)
## Loading required package: earth
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
# Plot the fitted MARS train object
plot(model.mars)
# Collect the four caret fits under the display names used in the comparison
model_list <- setNames(
  list(model.glm, model.glmnet, model.gam, model.mars),
  c("GLM", "GLMNET", "GAM", "MARS")
)
# Pool the cross-validation results of all models and summarise them
res <- resamples(model_list)
summary(res)
##
## Call:
## summary.resamples(object = res)
##
## Models: GLM, GLMNET, GAM, MARS
## Number of resamples: 10
##
## MAE
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## GLM 0.4260649 0.4346680 0.4394515 0.4405523 0.4455348 0.4644864 0
## GLMNET 0.4245508 0.4295505 0.4386702 0.4403408 0.4488301 0.4629820 0
## GAM 0.4042323 0.4194560 0.4239197 0.4228072 0.4275172 0.4421721 0
## MARS 0.4033817 0.4179151 0.4234065 0.4221353 0.4271785 0.4418075 0
##
## RMSE
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## GLM 0.5396656 0.5432773 0.5518663 0.5528841 0.5586044 0.5754816 0
## GLMNET 0.5335518 0.5401949 0.5490274 0.5525841 0.5630110 0.5797864 0
## GAM 0.5104325 0.5194098 0.5286186 0.5285509 0.5333212 0.5524970 0
## MARS 0.5086158 0.5180140 0.5277205 0.5276088 0.5322340 0.5523110 0
##
## Rsquared
## Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
## GLM 0.0908225 0.1321875 0.1427272 0.1416447 0.1616694 0.1766897 0
## GLMNET 0.1010237 0.1239272 0.1413950 0.1427619 0.1642361 0.1978363 0
## GAM 0.1631264 0.1988698 0.2126252 0.2154307 0.2352304 0.2640493 0
## MARS 0.1645263 0.2016386 0.2166463 0.2183207 0.2395080 0.2656206 0
# Box-and-whisker plot of the resampling results
bwplot(res)
# Load the test data (expected to provide the data frame `dat2`)
load("dat2.RData")
# Match factor levels to those in dat1
for (var in categorical_vars) {
  train_levels <- levels(dat1[[var]])
  test_values <- as.character(dat2[[var]])
  # factor() silently turns values outside `levels` into NA; report them so
  # that missing test-set predictions do not go unnoticed
  unseen <- unique(test_values[!is.na(test_values) & !(test_values %in% train_levels)])
  if (length(unseen) > 0) {
    warning(
      "dat2$", var, " contains values not seen in dat1 (set to NA): ",
      paste(unseen, collapse = ", "),
      call. = FALSE
    )
  }
  dat2[[var]] <- factor(dat2[[var]], levels = train_levels)
}
# Root mean squared error between observed and predicted values.
#
# actual:    numeric vector of observed values.
# predicted: numeric vector of predictions, same length as `actual`
#            (a length-1 value is recycled against the other argument).
# Returns a single numeric value; pairs in which either value is NA are
# ignored (na.rm = TRUE).
# Stops with an error when the two lengths differ and neither is 1, because
# R would otherwise recycle the shorter vector and compare misaligned pairs.
rmse <- function(actual, predicted) {
  n_actual <- length(actual)
  n_predicted <- length(predicted)
  if (n_actual != n_predicted && n_actual != 1L && n_predicted != 1L) {
    stop(
      "`actual` (length ", n_actual, ") and `predicted` (length ",
      n_predicted, ") must have the same length",
      call. = FALSE
    )
  }
  sqrt(mean((actual - predicted)^2, na.rm = TRUE))
}
# Test-set RMSE for every fitted model.
# vapply() returns one numeric value per model, so the table is built in a
# single step instead of being grown with rbind() inside a loop. The unused
# R2 column of the old empty template is gone; it was never filled.
test_rmse <- vapply(
  model_list,
  function(model) {
    preds <- predict(model, newdata = dat2)
    rmse(dat2$log_antibody, preds)
  },
  numeric(1)
)
results <- data.frame(
  Model = as.character(names(test_rmse)),
  # unname() keeps the default 1..n row names
  RMSE = unname(test_rmse)
)
print(results)
## Model RMSE
## 1 GLM 0.5707670
## 2 GLMNET 0.5732998
## 3 GAM 0.5704860
## 4 MARS 0.5327718
# Store each model's test-set predictions in dat2 as a `pred_<model>` column
for (fit_name in names(model_list)) {
  fit <- model_list[[fit_name]]
  pred_col <- paste0("pred_", fit_name)
  dat2[[pred_col]] <- predict(fit, newdata = dat2)
}
# Example plot for MARS: observed vs predicted, with the dashed red identity
# line marking perfect agreement
ggplot(data = dat2, mapping = aes(x = log_antibody, y = pred_MARS)) +
  geom_point(color = "blue", alpha = 0.3) +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  labs(
    title = "MARS: Observed vs Predicted on Test Data",
    x = "Observed log_antibody",
    y = "Predicted"
  ) +
  theme_minimal()